{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "## TensorFlow分布式训练\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "当我们拥有大量计算资源时，通过使用合适的分布式策略，我们可以充分利用这些计算资源，从而大幅压缩模型训练的时间。针对不同的使用场景，TensorFlow 在 tf.distribute.Strategy 中为我们提供了若干种分布式策略，使得我们能够更高效地训练模型。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 单机多卡训练： MirroredStrategy\n",
    "\n",
    "\n",
    "    tf.distribute.MirroredStrategy 是一种简单且高性能的，数据并行的同步式分布式策略，主要支持多个 GPU 在同一台主机上训练。使用这种策略时，我们只需实例化一个 MirroredStrategy 策略:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "strategy = tf.distribute.MirroredStrategy()\n",
    "\n",
    "#并将模型构建的代码放入 strategy.scope() 的上下文环境中:\n",
    "\n",
    "with strategy.scope():\n",
    "# 模型构建代码\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "可以在参数中指定设备，如:\n",
    "\n",
    "strategy = tf.distribute.MirroredStrategy(devices=[\"/gpu:0\", \"/gpu:1\"])\n",
    "\n",
    "\n",
    "即指定只使用第 0、1 号 GPU 参与分布式策略。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#以下代码展示了使用 MirroredStrategy 策略，在 TensorFlow Datasets 中的部分图像数据集上使用 Keras 训练 MobileNetV2 的过程：\n",
    "\n",
    "\n",
    "import tensorflow as tf\n",
    "import tensorflow_datasets as tfds\n",
    "\n",
    "num_epochs = 5\n",
    "batch_size_per_replica = 64\n",
    "learning_rate = 0.001\n",
    "\n",
    "strategy = tf.distribute.MirroredStrategy()\n",
    "print('Number of devices: %d' % strategy.num_replicas_in_sync)  # 输出设备数量\n",
    "batch_size = batch_size_per_replica * strategy.num_replicas_in_sync\n",
    "\n",
    "# 载入数据集并预处理\n",
    "def resize(image, label):\n",
    "    image = tf.image.resize(image, [224, 224]) / 255.0\n",
    "    return image, label\n",
    "\n",
    "# 使用 TensorFlow Datasets 载入猫狗分类数据集，详见“TensorFlow Datasets数据集载入”一章\n",
    "dataset = tfds.load(\"cats_vs_dogs\", split=tfds.Split.TRAIN, as_supervised=True)\n",
    "dataset = dataset.map(resize).shuffle(1024).batch(batch_size)\n",
    "\n",
    "with strategy.scope():\n",
    "    model = tf.keras.applications.MobileNetV2(weights=None, classes=2)\n",
    "    model.compile(\n",
    "        optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),\n",
    "        loss=tf.keras.losses.sparse_categorical_crossentropy,\n",
    "        metrics=[tf.keras.metrics.sparse_categorical_accuracy]\n",
    "    )\n",
    "\n",
    "model.fit(dataset, epochs=num_epochs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 多机训练： MultiWorkerMirroredStrategy"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "多机训练的方法和单机多卡类似，将 MirroredStrategy 更换为适合多机训练的 MultiWorkerMirroredStrategy 即可。不过，由于涉及到多台计算机之间的通讯，还需要进行一些额外的设置。具体而言，需要设置环境变量 TF_CONFIG ，示例如下:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "os.environ['TF_CONFIG'] = json.dumps({\n",
    "    'cluster': {\n",
    "        'worker': [\"localhost:20000\", \"localhost:20001\"]\n",
    "    },\n",
    "    'task': {'type': 'worker', 'index': 0}\n",
    "})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "TF_CONFIG 由 cluster 和 task 两部分组成：\n",
    "\n",
    "      1. cluster 说明了整个多机集群的结构和每台机器的网络地址（IP + 端口号）。对于每一台机器，cluster 的值都是相同的；\n",
    "\n",
    "      2. task 说明了当前机器的角色。例如， {'type': 'worker', 'index': 0} 说明当前机器是 cluster 中的第 0 个 worker（即 localhost:20000 ）。每一台机器的 task 值都需要针对当前主机进行分别的设置。\n",
    "\n",
    "以上内容设置完成后，在所有的机器上逐个运行训练代码即可。先运行的代码在尚未与其他主机连接时会进入监听状态，待整个集群的连接建立完毕后，所有的机器即会同时开始训练。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "请在各台机器上均注意防火墙的设置，尤其是需要开放与其他主机通信的端口。如上例的 0 号\n",
    "worker 需要开放 20000 端口，1 号 worker 需要开放 20001 端口。\n",
    "\"\"\"\n",
    "\"\"\"\n",
    "\n",
    "以下示例的训练任务与前节相同，只不过迁移到了多机训练环境。假设我们有两台机器，即首先在两台机器上均部署下面的程序，\n",
    "唯一的区别是 task 部分，第一台机器设置为 {'type': 'worker', 'index': 0} ，\n",
    "第二台机器设置为 {'type': 'worker', 'index': 1} 。接下来，在两台机器上依次运行程序，待通讯成功后，即会自动开始训练流程。\"\"\"\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "import tensorflow as tf\n",
    "import tensorflow_datasets as tfds\n",
    "import os\n",
    "import json\n",
    "\n",
    "num_epochs = 5\n",
    "batch_size_per_replica = 64\n",
    "learning_rate = 0.001\n",
    "\n",
    "num_workers = 2\n",
    "os.environ['TF_CONFIG'] = json.dumps({\n",
    "    'cluster': {\n",
    "        'worker': [\"localhost:20000\", \"localhost:20001\"]\n",
    "    },\n",
    "    'task': {'type': 'worker', 'index': 0}\n",
    "})\n",
    "strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()\n",
    "batch_size = batch_size_per_replica * num_workers\n",
    "\n",
    "def resize(image, label):\n",
    "    image = tf.image.resize(image, [224, 224]) / 255.0\n",
    "    return image, label\n",
    "\n",
    "dataset = tfds.load(\"cats_vs_dogs\", split=tfds.Split.TRAIN, as_supervised=True)\n",
    "dataset = dataset.map(resize).shuffle(1024).batch(batch_size)\n",
    "\n",
    "with strategy.scope():\n",
    "    model = tf.keras.applications.MobileNetV2(weights=None, classes=2)\n",
    "    model.compile(\n",
    "        optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),\n",
    "        loss=tf.keras.losses.sparse_categorical_crossentropy,\n",
    "        metrics=[tf.keras.metrics.sparse_categorical_accuracy]\n",
    "    )\n",
    "\n",
    "model.fit(dataset, epochs=num_epochs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
